{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Vanilla GAN with Multi GPUs + Naming Layers using OrderedDict\n",
    "# Code by GunhoChoi\n",
    "\n",
    "import os\n",
    "from collections import OrderedDict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.init as init\n",
    "import torch.utils as utils\n",
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as v_utils\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Set Hyperparameters\n",
    "# change num_gpus to the number of GPUs you want to use\n",
    "\n",
    "num_gpus = 1            # number of GPUs to train on\n",
    "batch_size = 200        # samples per mini-batch\n",
    "epoch = 1000            # number of passes over the training set\n",
    "learning_rate = 2e-4    # step size handed to both Adam optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded\n"
     ]
    }
   ],
   "source": [
    "# Download Data\n",
    "# MNIST training split, stored in the current directory, images converted to tensors\n",
    "\n",
    "mnist_train = dset.MNIST(\n",
    "    root=\"./\",\n",
    "    train=True,\n",
    "    transform=transforms.ToTensor(),\n",
    "    target_transform=None,\n",
    "    download=True\n",
    ")\n",
    "\n",
    "# Set Data Loader (input pipeline)\n",
    "# shuffled mini-batches of batch_size samples\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dataset=mnist_train,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Generator receives random noise z and creates a 1x28x28 image\n",
    "# we can name each layer using OrderedDict\n",
    "\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"Maps a batch of 100-dim noise vectors to batch x 1 x 28 x 28 images.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Generator,self).__init__()\n",
    "        # project the noise vector to a flattened 256 x 7 x 7 feature map\n",
    "        self.layer1 = nn.Linear(100,7*7*256)\n",
    "        self.layer2 = nn.Sequential(OrderedDict([\n",
    "                        ('conv1', nn.ConvTranspose2d(256,128,3,2,1,1)),  # batch x 128 x 14 x 14\n",
    "                        ('relu1', nn.LeakyReLU()),\n",
    "                        ('bn1', nn.BatchNorm2d(128)),\n",
    "                        ('conv2', nn.ConvTranspose2d(128,64,3,1,1)),     # batch x 64 x 14 x 14\n",
    "                        ('relu2', nn.LeakyReLU()),\n",
    "                        ('bn2', nn.BatchNorm2d(64))\n",
    "            ]))\n",
    "        self.layer3 = nn.Sequential(OrderedDict([\n",
    "                        ('conv3',nn.ConvTranspose2d(64,16,3,1,1)),       # batch x 16 x 14 x 14\n",
    "                        ('relu3',nn.LeakyReLU()),\n",
    "                        ('bn3',nn.BatchNorm2d(16)),\n",
    "                        ('conv4',nn.ConvTranspose2d(16,1,3,2,1,1)),      # batch x 1 x 28 x 28\n",
    "                        ('relu4',nn.LeakyReLU())\n",
    "            ]))\n",
    "\n",
    "    def forward(self,z):\n",
    "        \"\"\"Generate images from noise z of shape (batch, 100).\"\"\"\n",
    "        out = self.layer1(z)\n",
    "        # Take the batch dimension from the tensor itself instead of the global\n",
    "        # batch_size//num_gpus: a DataParallel replica only sees its own slice,\n",
    "        # and callers may pass a batch of any size.\n",
    "        out = out.view(out.size(0),256,7,7)\n",
    "        out = self.layer2(out)\n",
    "        out = self.layer3(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Discriminator receives 1x28x28 image and returns a float number 0~1\n",
    "# we can name each layer using OrderedDict\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "    \"\"\"Scores batch x 1 x 28 x 28 images with one sigmoid output per image.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Discriminator,self).__init__()\n",
    "        self.layer1 = nn.Sequential(OrderedDict([\n",
    "                        ('conv1',nn.Conv2d(1,16,3,padding=1)),   # batch x 16 x 28 x 28\n",
    "                        ('relu1',nn.LeakyReLU()),\n",
    "                        ('bn1',nn.BatchNorm2d(16)),\n",
    "                        ('conv2',nn.Conv2d(16,32,3,padding=1)),  # batch x 32 x 28 x 28\n",
    "                        ('relu2',nn.LeakyReLU()),\n",
    "                        ('bn2',nn.BatchNorm2d(32)),\n",
    "                        ('max1',nn.MaxPool2d(2,2))   # batch x 32 x 14 x 14\n",
    "        ]))\n",
    "        self.layer2 = nn.Sequential(OrderedDict([\n",
    "                        ('conv3',nn.Conv2d(32,64,3,padding=1)),  # batch x 64 x 14 x 14\n",
    "                        ('relu3',nn.LeakyReLU()),\n",
    "                        ('bn3',nn.BatchNorm2d(64)),\n",
    "                        ('max2',nn.MaxPool2d(2,2)),              # batch x 64 x 7 x 7\n",
    "                        ('conv4',nn.Conv2d(64,128,3,padding=1)),  # batch x 128 x 7 x 7\n",
    "                        ('relu4',nn.LeakyReLU())\n",
    "        ]))\n",
    "        # flattened 128 x 7 x 7 features -> single probability\n",
    "        self.fc = nn.Sequential(\n",
    "                        nn.Linear(128*7*7,1),\n",
    "                        nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self,x):\n",
    "        \"\"\"Return a (batch, 1) tensor of values in [0, 1] for images x.\"\"\"\n",
    "        out = self.layer1(x)\n",
    "        out = self.layer2(out)\n",
    "        # Flatten using the tensor's own batch dimension instead of the global\n",
    "        # batch_size//num_gpus: a DataParallel replica only sees its own slice,\n",
    "        # and the final batch of an epoch may be smaller.\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Put class objects on multiple GPUs using\n",
    "# torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)\n",
    "# device_ids: default all devices / output_device: default device 0\n",
    "# along with .cuda()\n",
    "\n",
    "# Pass device_ids explicitly: with the default (None) DataParallel uses every\n",
    "# visible GPU, which would ignore the num_gpus hyperparameter.\n",
    "device_ids = list(range(num_gpus))\n",
    "\n",
    "generator = nn.DataParallel(Generator(), device_ids=device_ids).cuda()\n",
    "discriminator = nn.DataParallel(Discriminator(), device_ids=device_ids).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "module.layer1.weight\n",
      "module.layer1.bias\n",
      "module.layer2.conv1.weight\n",
      "module.layer2.conv1.bias\n",
      "module.layer2.bn1.weight\n",
      "module.layer2.bn1.bias\n",
      "module.layer2.bn1.running_mean\n",
      "module.layer2.bn1.running_var\n",
      "module.layer2.conv2.weight\n",
      "module.layer2.conv2.bias\n",
      "module.layer2.bn2.weight\n",
      "module.layer2.bn2.bias\n",
      "module.layer2.bn2.running_mean\n",
      "module.layer2.bn2.running_var\n",
      "module.layer3.conv3.weight\n",
      "module.layer3.conv3.bias\n",
      "module.layer3.bn3.weight\n",
      "module.layer3.bn3.bias\n",
      "module.layer3.bn3.running_mean\n",
      "module.layer3.bn3.running_var\n",
      "module.layer3.conv4.weight\n",
      "module.layer3.conv4.bias\n"
     ]
    }
   ],
   "source": [
    "# Get parameter list by using class.state_dict().keys()\n",
    "# the names come from the OrderedDict keys, prefixed with \"module.\" by DataParallel\n",
    "\n",
    "dis_params = discriminator.state_dict().keys()\n",
    "gen_params = generator.state_dict().keys()\n",
    "\n",
    "# show the generator's entries, one name per line\n",
    "for param_name in gen_params:\n",
    "    print(param_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# loss function, optimizers, and labels for training\n",
    "\n",
    "# binary cross-entropy on the discriminator's sigmoid output\n",
    "loss_func = nn.BCELoss()\n",
    "\n",
    "# one Adam optimizer per network, both using learning_rate\n",
    "dis_optim = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)\n",
    "gen_optim = torch.optim.Adam(generator.parameters(), lr=learning_rate)\n",
    "\n",
    "# fixed (batch_size, 1) target columns on the GPU: ones = real, zeros = fake\n",
    "zeros_label = Variable(torch.zeros(batch_size, 1)).cuda()\n",
    "ones_label = Variable(torch.ones(batch_size, 1)).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--------model not restored--------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# model restore if any\n",
    "\n",
    "# NOTE: torch.load unpickles the file, which can execute arbitrary code;\n",
    "# only restore checkpoints that you created yourself.\n",
    "try:\n",
    "    restored_gen, restored_dis = torch.load('./model/vanilla_gan.pkl')\n",
    "    # Copy the weights into the existing models instead of rebinding the names:\n",
    "    # gen_optim / dis_optim were built from these objects' parameters, so\n",
    "    # swapping in new objects would leave the optimizers updating stale weights.\n",
    "    generator.load_state_dict(restored_gen.state_dict())\n",
    "    discriminator.load_state_dict(restored_dis.state_dict())\n",
    "    print(\"\\n--------model restored--------\\n\")\n",
    "except Exception as err:\n",
    "    # Best effort: a missing or unreadable checkpoint just means a fresh start.\n",
    "    # Unlike a bare except, this no longer swallows KeyboardInterrupt/SystemExit,\n",
    "    # and the reason is reported instead of being hidden.\n",
    "    print(\"\\n--------model not restored--------\\n\")\n",
    "    print(\"reason: {}: {}\".format(type(err).__name__, err))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:147: UserWarning: Couldn't retrieve source code for container of type Generator. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n",
      "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:147: UserWarning: Couldn't retrieve source code for container of type Discriminator. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0th iteration gen_loss: \n",
      " 0.6704\n",
      "[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      " dis_loss: \n",
      " 1.3712\n",
      "[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# train\n",
    "\n",
    "# make sure the checkpoint and sample folders exist before the first save,\n",
    "# otherwise a whole epoch of work is lost to a missing-directory error\n",
    "os.makedirs('./model', exist_ok=True)\n",
    "os.makedirs('./gan_result', exist_ok=True)\n",
    "\n",
    "for i in range(epoch):\n",
    "    for j,(image,label) in enumerate(train_loader):\n",
    "        image = Variable(image).cuda()\n",
    "        # the final batch of an epoch can hold fewer than batch_size images\n",
    "        real_count = image.size(0)\n",
    "\n",
    "        # generator: five updates for every discriminator update\n",
    "        for k in range(5):\n",
    "            z = Variable(torch.rand(batch_size,100)).cuda()\n",
    "            gen_optim.zero_grad()\n",
    "            gen_fake = generator(z)\n",
    "            dis_fake = discriminator(gen_fake)\n",
    "            gen_loss = loss_func(dis_fake,ones_label) # fake classified as real\n",
    "            gen_loss.backward()\n",
    "            gen_optim.step()\n",
    "\n",
    "        # discriminator\n",
    "        dis_optim.zero_grad()\n",
    "        # Score the latest fakes again on a detached copy: the graph built in the\n",
    "        # generator step refers to weights that gen_optim.step() has since changed\n",
    "        # in place, and the discriminator update must not reach into the generator.\n",
    "        dis_fake = discriminator(gen_fake.detach())\n",
    "        dis_real = discriminator(image)\n",
    "        dis_loss = loss_func(dis_fake,zeros_label) + loss_func(dis_real,ones_label[:real_count])\n",
    "        dis_loss.backward()\n",
    "        dis_optim.step()\n",
    "\n",
    "    # model save\n",
    "    if i % 5 == 0:\n",
    "        torch.save([generator,discriminator],'./model/vanilla_gan.pkl')\n",
    "\n",
    "    # report the losses of the last batch and save 20 generated samples (5 per row)\n",
    "    print(\"{}th iteration gen_loss: {} dis_loss: {}\".format(i,gen_loss.data,dis_loss.data))\n",
    "    v_utils.save_image(gen_fake.data[0:20],\"./gan_result/gen_{}.png\".format(i), nrow=5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
